"""
Phase 3: Survival Analysis
Development Threshold Paper

Kaplan-Meier curves, Cox proportional hazards models, competing risks.
Uses panel structure: time-to-exit from $9k-$25k zone.
"""

import pandas as pd
import numpy as np
from scipy import stats
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
# Main country-year panel, restricted to years up to and including 2024.
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full[full['year'] <= 2024].copy()

# Resource rents are an optional add-on: when the file is missing, unreadable
# or lacks the expected columns, fall back to an all-NaN column so the rest of
# the script still runs. Only I/O, parsing and schema errors are caught, so
# Ctrl-C and genuine programming errors still surface.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, ValueError) as exc:
    print(f"Warning: resource rents not merged ({exc!r}); filling resource_rents_gdp with NaN")
    panel['resource_rents_gdp'] = np.nan

# Per-country crossing table (zone entry year, status, covariates at entry).
cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')

# Bounds of the income zone, compared against gdp_pc_ppp below.
LOWER = 9000
UPPER = 25000

# ═══════════════════════════════════════════════════════════════════════
# 3a. Build survival dataset
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("3a. BUILDING SURVIVAL DATASET")
print("=" * 70)

# One record per country that entered the zone.
#   duration  = years between zone entry (cross_9k) and the recorded exit year
#   event     = 1 if the country reached UPPER, 0 otherwise
#   exit_type = 'above' (reached UPPER), 'below' (dropped under LOWER after
#               entry), or 'censored' (neither observed; last year with GDP data)
# NOTE(review): reaching UPPER takes precedence even when the country dropped
# below LOWER at an earlier date; confirm this is the intended definition of
# "first exit" for the competing-risks section.

entered_zone = cross_df['cross_9k'].notna()
not_always_above = cross_df['status'] != 'Always above'
zone_entrants = cross_df[entered_zone & not_always_above].copy()

# Entry-time covariates carried over unchanged from the crossing table.
entry_covariates = (
    'Z_1_entry',
    'old_dep_entry',
    'youth_dep_entry',
    'kaopen_entry',
    'gdp_pc_ppp_entry',
    'gross_savings_gdp_entry',
    'resource_rents_gdp_entry',
    'rgdp_growth_entry',
    'trade_openness_entry',
)

survival_records = []
for _, row in zone_entrants.iterrows():
    iso = row['iso3']
    entry_year = row['cross_9k']

    # Country history from the entry year onwards, keeping only years with GDP.
    from_entry = (panel['iso3'] == iso) & (panel['year'] >= entry_year)
    history = panel[from_entry].sort_values('year')
    gdp_s = history[history['gdp_pc_ppp'].notna()]
    if gdp_s.empty:
        continue

    # Years at or above the upper bound, and later years under the lower bound.
    above = gdp_s[gdp_s['gdp_pc_ppp'] >= UPPER]
    after_entry = gdp_s[gdp_s['year'] > entry_year]
    below = after_entry[after_entry['gdp_pc_ppp'] < LOWER]

    if len(above) > 0:
        # Exited above
        exit_year, event, exit_type = above['year'].min(), 1, 'above'
    elif len(below) > 0:
        # Fell back (censored for the "exit above" analysis)
        exit_year, event, exit_type = below['year'].min(), 0, 'below'
    else:
        # Still in zone (censored)
        exit_year, event, exit_type = gdp_s['year'].max(), 0, 'censored'
    duration = exit_year - entry_year

    record = {
        'iso3': iso,
        'entry_year': entry_year,
        'exit_year': exit_year,
        'duration': duration,
        'event': event,
        'exit_type': exit_type,
    }
    record.update({name: row.get(name) for name in entry_covariates})
    survival_records.append(record)

surv = pd.DataFrame(survival_records)
surv.to_csv('/mnt/c/demographics_capital_flows/development_threshold/data/survival_data.csv', index=False)

print(f"Survival dataset: {len(surv)} countries")
print(f"  Events (exited above): {(surv['event']==1).sum()}")
print(f"  Censored: {(surv['event']==0).sum()}")
print(f"    Still in zone: {(surv['exit_type']=='censored').sum()}")
print(f"    Fell back: {(surv['exit_type']=='below').sum()}")
print(f"  Duration range: {surv['duration'].min():.0f} to {surv['duration'].max():.0f} years")

# ═══════════════════════════════════════════════════════════════════════
# 3b. Kaplan-Meier survival estimates by demographic tercile
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3b. KAPLAN-MEIER SURVIVAL BY DEMOGRAPHIC TERCILE")
print("=" * 70)

def kaplan_meier(durations, events):
    """Compute the Kaplan-Meier product-limit survival estimate.

    Parameters
    ----------
    durations : array-like
        Observed time for each subject (list, ndarray or Series).
    events : array-like
        Same length as ``durations``; a value equal to 1 marks an observed
        event, any other value is treated as censored.

    Returns
    -------
    (km_times, km_survival) : tuple of lists
        Step points of the survival curve. The curve always starts at
        ``(0, 1.0)`` and then has one point per distinct duration, in
        ascending order, including times at which only censoring occurred.
        Empty input yields ``([0], [1.0])``.
    """
    # Coerce up front so that elementwise comparisons work for plain lists too.
    durations = np.asarray(durations)
    events = np.asarray(events)

    n_at_risk = len(durations)
    km_times = [0]
    km_survival = [1.0]

    for t in np.unique(durations):
        if n_at_risk <= 0:
            break
        at_t = durations == t
        deaths = int(np.count_nonzero(at_t & (events == 1)))
        km_times.append(t)
        km_survival.append(km_survival[-1] * (1 - deaths / n_at_risk))
        # Everyone observed at t leaves the risk set, whether they had the
        # event or were censored (including any non-0/1 event codes).
        n_at_risk -= int(np.count_nonzero(at_t))

    return km_times, km_survival

# Split by Z₁ tercile at entry
# Keep countries with a Z₁ value at entry and cut them into three equal-count groups.
surv_z1 = surv[surv['Z_1_entry'].notna()].copy()
surv_z1['z1_tercile'] = pd.qcut(surv_z1['Z_1_entry'], 3, labels=['Young', 'Middle', 'Old'])

print("\nMedian survival time (years to exit above $25k):")
for tercile in ('Young', 'Middle', 'Old'):
    sub = surv_z1[surv_z1['z1_tercile'] == tercile]
    times, survival = kaplan_meier(sub['duration'].values, sub['event'].values)
    # Median = first step at which the survival estimate is at or below 0.5;
    # None when the curve never gets that low.
    median_t = next((t for t, s in zip(times, survival) if s <= 0.5), None)
    if median_t is None:
        med_str = ">censored"
    else:
        med_str = f"{median_t:.0f}"
    n_events = sub['event'].sum()
    z1_low = sub['Z_1_entry'].min()
    z1_high = sub['Z_1_entry'].max()
    print(f"  {tercile:8s}: n={len(sub)}, events={n_events}, median survival={med_str} years, "
          f"Z₁ range=[{z1_low:.2f}, {z1_high:.2f}]")

# Extreme terciles, compared by the log-rank test further down.
young = surv_z1[surv_z1['z1_tercile'] == 'Young']
old = surv_z1[surv_z1['z1_tercile'] == 'Old']

# Manual log-rank test
def log_rank_test(dur1, ev1, dur2, ev2):
    """Two-sample log-rank test.

    Parameters
    ----------
    dur1, dur2 : array-like
        Observed times for group 1 and group 2 (list, ndarray or Series).
    ev1, ev2 : array-like
        Event indicators aligned with ``dur1`` / ``dur2``; a value equal to 1
        marks an observed event, any other value is treated as censored.

    Returns
    -------
    (test_stat, p_value) : tuple of float
        Chi-square statistic with 1 degree of freedom and its p-value, or
        ``(nan, nan)`` when the accumulated variance is zero (for example when
        no events are observed while more than one subject is at risk).
    """
    # Coerce up front so that elementwise comparisons work for plain lists too.
    dur1, ev1 = np.asarray(dur1), np.asarray(ev1)
    dur2, ev2 = np.asarray(dur2), np.asarray(ev2)

    observed_minus_expected = 0.0
    variance = 0.0
    at_risk_1 = len(dur1)
    at_risk_2 = len(dur2)

    for t in np.unique(np.concatenate([dur1, dur2])):
        at_t_1 = dur1 == t
        at_t_2 = dur2 == t
        d1 = int(np.count_nonzero(at_t_1 & (ev1 == 1)))
        d2 = int(np.count_nonzero(at_t_2 & (ev2 == 1)))
        d = d1 + d2
        n = at_risk_1 + at_risk_2

        if n > 1 and d > 0:
            # Expected group-1 events under the null, and the matching
            # hypergeometric variance term.
            e1 = at_risk_1 * d / n
            observed_minus_expected += d1 - e1
            variance += e1 * (1 - at_risk_1 / n) * (n - d) / (n - 1)

        # Everyone observed at t leaves the risk set (event or censored).
        at_risk_1 -= int(np.count_nonzero(at_t_1))
        at_risk_2 -= int(np.count_nonzero(at_t_2))
        if at_risk_1 <= 0 and at_risk_2 <= 0:
            break

    if variance > 0:
        test_stat = observed_minus_expected ** 2 / variance
        # Survival function instead of 1 - cdf: keeps precision for very
        # small p-values, which would otherwise round to exactly 0.
        p_value = stats.chi2.sf(test_stat, 1)
        return float(test_stat), float(p_value)
    return np.nan, np.nan

# Compare the youngest and oldest Z₁ terciles; the middle tercile is left out.
young_dur, young_ev = young['duration'].values, young['event'].values
old_dur, old_ev = old['duration'].values, old['event'].values
lr_stat, lr_p = log_rank_test(young_dur, young_ev, old_dur, old_ev)
print(f"\nLog-rank test (Young vs Old): χ²={lr_stat:.3f}, p={lr_p:.4f}")

# By KAOPEN tercile
surv_kao = surv[surv['kaopen_entry'].notna()].copy()
# duplicates='drop' merges bins whenever tied values make quantile edges
# coincide. Passing three fixed labels to qcut then raises ValueError, so bin
# with integer codes first and attach names according to how many bins survive.
kao_codes, kao_edges = pd.qcut(surv_kao['kaopen_entry'], 3, labels=False,
                               retbins=True, duplicates='drop')
kao_labels_by_bin_count = {
    3: ['Closed', 'Middle', 'Open'],
    2: ['Closed', 'Open'],
    1: ['Middle'],
}
kao_labels = kao_labels_by_bin_count.get(len(kao_edges) - 1, [])
surv_kao['kao_tercile'] = kao_codes.map(dict(enumerate(kao_labels)))

print("\nBy KAOPEN tercile:")
for tercile in ['Closed', 'Middle', 'Open']:
    sub = surv_kao[surv_kao['kao_tercile'] == tercile]
    if len(sub) == 0:
        # Group absent, e.g. because bins were merged above.
        continue
    times, survival = kaplan_meier(sub['duration'].values, sub['event'].values)
    # Median = first step at which the survival estimate is at or below 0.5.
    median_t = next((t for t, s in zip(times, survival) if s <= 0.5), None)
    med_str = f"{median_t:.0f}" if median_t is not None else ">censored"
    print(f"  {tercile:8s}: n={len(sub)}, events={sub['event'].sum()}, median={med_str}")

# By resource exporter status
# Flag countries whose resource rents at entry are at least 10 (same units as
# the resource_rents_gdp_entry column).
surv_res = surv[surv['resource_rents_gdp_entry'].notna()].copy()
is_commodity = surv_res['resource_rents_gdp_entry'] >= 10
surv_res['commodity'] = is_commodity.astype(int)

print("\nBy commodity exporter status:")
for val, label in enumerate(('Non-commodity', 'Commodity')):
    sub = surv_res[surv_res['commodity'] == val]
    if sub.empty:
        continue
    times, survival = kaplan_meier(sub['duration'].values, sub['event'].values)
    # Median = first step at which the survival estimate is at or below 0.5.
    median_t = next((t for t, s in zip(times, survival) if s <= 0.5), None)
    if median_t is None:
        med_str = ">censored"
    else:
        med_str = f"{median_t:.0f}"
    print(f"  {label:15s}: n={len(sub)}, events={sub['event'].sum()}, median={med_str}")

# ═══════════════════════════════════════════════════════════════════════
# 3c. Cox Proportional Hazard (using statsmodels PHReg)
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3c. COX PROPORTIONAL HAZARD MODELS")
print("=" * 70)

from statsmodels.duration.hazard_regression import PHReg


def _sig_stars(p_value):
    """Return '***', '**', '*' or '' for p below 0.01, 0.05, 0.1, or none."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.1:
        return '*'
    return ''


# Common estimation sample: complete cases on all four covariates, and
# strictly positive durations only.
cox_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'resource_rents_gdp_entry']
cox_sample = surv[['duration', 'event'] + cox_vars].dropna()
cox_sample = cox_sample[cox_sample['duration'] > 0]

# Express GDP per capita in thousands so the hazard ratio is per $1k.
cox_sample['gdp_pc_ppp_entry_k'] = cox_sample['gdp_pc_ppp_entry'] / 1000

print(f"\nCox sample: {len(cox_sample)} countries, {cox_sample['event'].sum()} events")

# Model 1: Z₁ only
try:
    cox1 = PHReg(
        cox_sample['duration'],
        cox_sample[['Z_1_entry']],
        status=cox_sample['event'],
    ).fit()
    print(f"\n  Cox M1: Z₁ only")
    z1_coef, z1_p = cox1.params[0], cox1.pvalues[0]
    print(f"    Z₁: HR={np.exp(z1_coef):.3f}, coef={z1_coef:.4f}, p={z1_p:.4f}")
except Exception as e:
    print(f"  Cox M1 failed: {e}")

# Model 2: Z₁ + income
m2_vars = ['Z_1_entry', 'gdp_pc_ppp_entry_k']
try:
    cox2 = PHReg(
        cox_sample['duration'],
        cox_sample[m2_vars],
        status=cox_sample['event'],
    ).fit()
    print(f"\n  Cox M2: Z₁ + income")
    for var, coef, p_val in zip(m2_vars, cox2.params, cox2.pvalues):
        print(f"    {var}: HR={np.exp(coef):.3f}, coef={coef:.4f}, p={p_val:.4f}{_sig_stars(p_val)}")
except Exception as e:
    print(f"  Cox M2 failed: {e}")

# Model 3: Full
cox_full_vars = ['Z_1_entry', 'gdp_pc_ppp_entry_k', 'kaopen_entry', 'resource_rents_gdp_entry']
present_full_vars = [v for v in cox_full_vars if v in cox_sample.columns]
cox_full_sample = cox_sample[['duration', 'event'] + present_full_vars].dropna()
cox_full_sample = cox_full_sample[cox_full_sample['duration'] > 0]

try:
    available_cox = [v for v in cox_full_vars if v in cox_full_sample.columns]
    cox3 = PHReg(
        cox_full_sample['duration'],
        cox_full_sample[available_cox],
        status=cox_full_sample['event'],
    ).fit()
    print(f"\n  Cox M3: Full model (n={len(cox_full_sample)})")
    cox_results = []
    for var, coef, se, p_val in zip(available_cox, cox3.params, cox3.bse, cox3.pvalues):
        hr = np.exp(coef)
        print(f"    {var}: HR={hr:.3f}, coef={coef:.4f}, se={se:.4f}, p={p_val:.4f}{_sig_stars(p_val)}")
        cox_results.append({
            'variable': var,
            'hazard_ratio': hr,
            'coefficient': coef,
            'std_error': se,
            'p_value': p_val,
        })
    # One row per covariate of the full model.
    pd.DataFrame(cox_results).to_csv(
        '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table9_cox.csv', index=False)
except Exception as e:
    print(f"  Cox M3 failed: {e}")

# ═══════════════════════════════════════════════════════════════════════
# 3d. Competing risks: exit above vs fall below
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3d. COMPETING RISKS: EXIT ABOVE vs FALL BELOW")
print("=" * 70)

# Indicator columns for the two exit routes (above $25k / below $9k).
# Only raw outcome shares and group means are reported here; no cumulative
# incidence curves are estimated in this section.
exit_kind = surv['exit_type']
surv['event_above'] = (exit_kind == 'above').astype(int)
surv['event_below'] = (exit_kind == 'below').astype(int)

n_above = surv['event_above'].sum()
n_below = surv['event_below'].sum()
n_censored = (exit_kind == 'censored').sum()

print(f"\nCompeting risks summary:")
print(f"  Exited above $25k: {n_above} ({n_above/len(surv)*100:.1f}%)")
print(f"  Fell below $9k: {n_below} ({n_below/len(surv)*100:.1f}%)")
print(f"  Still in zone: {n_censored} ({n_censored/len(surv)*100:.1f}%)")

# Mean entry characteristics for each exit outcome.
for kind in ('above', 'below', 'censored'):
    group = surv[exit_kind == kind]
    if group.empty:
        continue
    z1 = group['Z_1_entry'].mean()
    gdp = group['gdp_pc_ppp_entry'].mean()
    print(f"  {kind:10s}: n={len(group)}, mean Z₁={z1:.3f}, mean GDP/cap entry=${gdp:,.0f}")

# Welch t-test: does Z₁ at entry differ between exit-above and fell-below?
above_z1 = surv.loc[exit_kind == 'above', 'Z_1_entry'].dropna()
below_z1 = surv.loc[exit_kind == 'below', 'Z_1_entry'].dropna()
if len(above_z1) > 2 and len(below_z1) > 2:
    t_stat, p_val = stats.ttest_ind(above_z1, below_z1, equal_var=False)
    print(f"\n  Z₁ at entry: exit-above={above_z1.mean():.3f} vs fell-below={below_z1.mean():.3f}, p={p_val:.4f}")

# ═══════════════════════════════════════════════════════════════════════
# 3e. Demographic acceleration during transit
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3e. DEMOGRAPHIC CHANGE DURING TRANSIT")
print("=" * 70)

# Attach each country's Z₁ change while in the zone (from the crossing table)
# and compare countries that exited above with all the others.
dz1_lookup = cross_df[['iso3', 'Z_1_change_in_zone']].dropna()
surv_dz1 = surv.merge(dz1_lookup, on='iso3', how='left')

exited_above = surv_dz1['exit_type'] == 'above'
above_dz1 = surv_dz1.loc[exited_above, 'Z_1_change_in_zone'].dropna()
not_above_dz1 = surv_dz1.loc[~exited_above, 'Z_1_change_in_zone'].dropna()

if len(above_dz1) > 2 and len(not_above_dz1) > 2:
    # Welch t-test (unequal variances).
    t_stat, p_val = stats.ttest_ind(above_dz1, not_above_dz1, equal_var=False)
    crosser_mean = above_dz1.mean()
    non_crosser_mean = not_above_dz1.mean()
    print(f"ΔZ₁ during transit: crossers={crosser_mean:.3f}, non-crossers={non_crosser_mean:.3f}, p={p_val:.4f}")
    print(f"  Crossers aged by {crosser_mean:.3f} Z₁ units during transit")
    print(f"  Non-crossers aged by {non_crosser_mean:.3f} Z₁ units")

# ═══════════════════════════════════════════════════════════════════════
# 3f. KAOPEN threshold: is there a minimum for crossing?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3f. INSTITUTIONAL THRESHOLD: MINIMUM KAOPEN FOR CROSSING")
print("=" * 70)

# Crossers exited above $25k (event == 1); everyone else is a non-crosser.
crossers = surv[surv['event'] == 1]
non_crossers = surv[surv['event'] == 0]

crosser_kao = crossers['kaopen_entry'].dropna()
non_crosser_kao = non_crossers['kaopen_entry'].dropna()

# Only report when more than five crossers have a KAOPEN value at entry.
if len(crosser_kao) > 5:
    min_kao_crosser = crosser_kao.min()
    pct_10_crosser = crosser_kao.quantile(0.10)
    print(f"Minimum KAOPEN among crossers: {min_kao_crosser:.2f}")
    print(f"10th percentile KAOPEN among crossers: {pct_10_crosser:.2f}")

    # Share of non-crossers whose KAOPEN lies strictly below the crosser minimum.
    nc_below = (non_crosser_kao < min_kao_crosser).mean()
    print(f"Fraction of non-crossers below crosser minimum: {nc_below*100:.1f}%")

    # Side-by-side quantiles of the two KAOPEN distributions.
    print(f"\nKAOPEN distribution:")
    for q in (0.1, 0.25, 0.5, 0.75, 0.9):
        c_val = crosser_kao.quantile(q)
        nc_val = non_crosser_kao.quantile(q)
        print(f"  p{int(q*100):2d}: crossers={c_val:.2f}, non-crossers={nc_val:.2f}")

# Export the Kaplan-Meier step points per Z₁ tercile (long format:
# tercile, time, survival) for plotting.
km_data = []
for tercile in ('Young', 'Middle', 'Old'):
    group = surv_z1[surv_z1['z1_tercile'] == tercile]
    times, survival = kaplan_meier(group['duration'].values, group['event'].values)
    km_data.extend(
        {'tercile': tercile, 'time': t, 'survival': s}
        for t, s in zip(times, survival)
    )

km_table = pd.DataFrame(km_data)
km_table.to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table10_kaplan_meier.csv', index=False)

print("\n✓ Phase 3 complete.")
